import random
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, OptimizerHook, build_optimizer)
from mmcv.utils import build_from_cfg
from mmdet.core import DistEvalHook, EvalHook, Fp16OptimizerHook
from mmdet.datasets import build_dataloader, build_dataset
from mmdet.utils import get_root_logger


def set_random_seed(seed, deterministic=False):
    """Seed the random number generators used by this module.

    The same seed is handed to Python's ``random``, to NumPy and to torch
    (``torch.manual_seed`` and ``torch.cuda.manual_seed_all``).

    Args:
        seed (int): Seed to be used.
        deterministic (bool): When True, additionally set
            `torch.backends.cudnn.deterministic` to True and
            `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
    if not deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def _sync_samples_per_gpu(cfg, logger):
    """Copy the deprecated ``cfg.data.imgs_per_gpu`` into ``samples_per_gpu``.

    Does nothing when ``imgs_per_gpu`` is absent from ``cfg.data``. Otherwise
    a deprecation warning is logged and ``imgs_per_gpu`` overrides whatever
    ``samples_per_gpu`` was set to.
    """
    if 'imgs_per_gpu' not in cfg.data:
        return
    logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. Please use "samples_per_gpu" instead')
    if 'samples_per_gpu' in cfg.data:
        logger.warning(
            f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and "samples_per_gpu"={cfg.data.samples_per_gpu},'
            f' "imgs_per_gpu"={cfg.data.imgs_per_gpu} is used in this experiments')
    else:
        logger.warning(
            f'Automatically set "samples_per_gpu"="imgs_per_gpu"={cfg.data.imgs_per_gpu} in this experiments')
    cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu


def _build_train_loaders(datasets, cfg, distributed):
    """Build one training dataloader per dataset in ``datasets``."""
    return [build_dataloader(ds, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu, len(cfg.gpu_ids),
                             dist=distributed, seed=cfg.seed) for ds in datasets]


def train_detector(model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None):
    """Build dataloaders, runner and hooks, then run training.

    Args:
        model: Detector to train. It is moved to GPU and wrapped in
            ``MMDistributedDataParallel`` or ``MMDataParallel``.
        dataset: When ``cfg`` is a single config, a dataset or a list/tuple
            of datasets. When ``cfg`` is a list, an indexable pair where
            ``dataset[0]`` and ``dataset[1]`` are iterables of datasets
            (presumably the labeled and unlabeled sets - confirm with caller).
        cfg: A config object, or a list of configs of which only the first
            element is used.
        distributed (bool): Whether to train in distributed mode.
            Default: False.
        validate (bool): Whether to build a validation dataloader from
            ``cfg.data.val`` and register an evaluation hook. Default: False.
        timestamp: Value stored on ``runner.timestamp``. Default: None.
        meta: Passed through to ``EpochBasedRunner``. Default: None.
    """
    two_pools = isinstance(cfg, list)
    if two_pools:
        cfg = cfg[0]  # config used in this file are the same for cfg and cfg_u
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    if two_pools:
        dataset_L, dataset_U = dataset[0], dataset[1]
    else:
        dataset_L = dataset if isinstance(dataset, (list, tuple)) else [dataset]
        dataset_U = None
    _sync_samples_per_gpu(cfg, logger)
    data_loaders_L = _build_train_loaders(dataset_L, cfg, distributed)
    data_loaders_U = None if dataset_U is None else _build_train_loaders(dataset_U, cfg, distributed)

    # put model on gpus
    if distributed:
        # `find_unused_parameters` of torch.nn.parallel.DistributedDataParallel
        # is fixed to True here; it is not read from cfg.
        model = MMDistributedDataParallel(model.cuda(), device_ids=[torch.cuda.current_device()],
                                          broadcast_buffers=False, find_unused_parameters=True)
    else:
        model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    # NOTE: only the learning rate comes from cfg.optimizer; momentum and
    # weight decay are fixed, and only parameters with requires_grad are optimized.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=cfg.optimizer.lr, momentum=0.9, weight_decay=0.0005)
    runner = EpochBasedRunner(model, optimizer=optimizer, work_dir=cfg.work_directory, logger=logger, meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config, cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(val_dataset, samples_per_gpu=1, workers_per_gpu=cfg.data.workers_per_gpu,
                                          dist=distributed, shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in custom_hooks:
            assert isinstance(hook_cfg, dict), f'Each item in custom_hooks expects dict type, but got {type(hook_cfg)}'
            # Work on a copy so popping 'priority' does not mutate the config.
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    # resume_from takes precedence over load_from
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    if data_loaders_U is None:
        runner.run(data_loaders_L, cfg.workflow, cfg.total_epochs)
    else:
        # NOTE(review): the runner receives a pair of loader lists here;
        # confirm the EpochBasedRunner in use accepts this nested form.
        runner.run([data_loaders_L, data_loaders_U], cfg.workflow, cfg.total_epochs)
